import torch

# torch.cat: join existing tensors along an existing dimension.
# Two (2, 4) tensors concatenated on dim 0 yield a (4, 4) tensor
# (zeros in rows 0-1, ones in rows 2-3); no new dimension is created.
print("\ncat")
a = torch.zeros(2, 4)
b = torch.ones(2, 4)
print(a, b, sep="\n")
out = torch.cat([a, b], dim=0)
print(out, out.shape)

# torch.stack: insert a NEW dimension and place the inputs along it.
# a holds 1..6 and b holds 7..12, each viewed as (2, 3). Stacking on
# dim=1 yields shape (2, 2, 3), where out[:, 0] == a and out[:, 1] == b.
print("\nstack")
a = torch.linspace(start=1, end=6, steps=6).view(2, 3)
b = torch.linspace(start=7, end=12, steps=6).view(2, 3)
print(a, b, sep="\n")
out = torch.stack([a, b], dim=1)
print(out, out.shape)